{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "d9bc39df",
   "metadata": {
    "origin_pos": 0
   },
   "source": [
    "# Word Similarity and Analogy\n",
    ":label:`sec_synonyms`\n",
    "\n",
    "In :numref:`sec_word2vec_pretraining`,\n",
    "we trained a word2vec model on a small dataset\n",
    "and applied it\n",
    "to find semantically similar words\n",
    "for an input word.\n",
    "In practice,\n",
    "word vectors that are pretrained\n",
    "on large corpora can be\n",
    "applied to downstream\n",
    "natural language processing tasks,\n",
    "which will be covered later\n",
    "in :numref:`chap_nlp_app`.\n",
    "To demonstrate the\n",
    "semantics of word vectors pretrained\n",
    "on large corpora in a straightforward way,\n",
    "let's apply them\n",
    "to word similarity and analogy tasks.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "7c36748f",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:39:58.110484Z",
     "iopub.status.busy": "2023-08-18T19:39:58.109802Z",
     "iopub.status.idle": "2023-08-18T19:40:01.601183Z",
     "shell.execute_reply": "2023-08-18T19:40:01.599880Z"
    },
    "origin_pos": 2,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "from torch import nn\n",
    "from d2l import torch as d2l"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "58dea0ce",
   "metadata": {
    "origin_pos": 3
   },
   "source": [
    "## Loading Pretrained Word Vectors\n",
    "\n",
    "Below lists pretrained GloVe embeddings of dimension 50, 100, and 300,\n",
    "which can be downloaded from the [GloVe website](https://nlp.stanford.edu/projects/glove/).\n",
    "The pretrained fastText embeddings are available in multiple languages.\n",
    "Here we consider one English version (300-dimensional \"wiki.en\") that can be downloaded from the\n",
    "[fastText website](https://fasttext.cc/).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "740b8826",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:40:01.605791Z",
     "iopub.status.busy": "2023-08-18T19:40:01.605008Z",
     "iopub.status.idle": "2023-08-18T19:40:01.610925Z",
     "shell.execute_reply": "2023-08-18T19:40:01.609999Z"
    },
    "origin_pos": 4,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "#@save\n",
    "d2l.DATA_HUB['glove.6b.50d'] = (d2l.DATA_URL + 'glove.6B.50d.zip',\n",
    "                                '0b8703943ccdb6eb788e6f091b8946e82231bc4d')\n",
    "\n",
    "#@save\n",
    "d2l.DATA_HUB['glove.6b.100d'] = (d2l.DATA_URL + 'glove.6B.100d.zip',\n",
    "                                 'cd43bfb07e44e6f27cbcc7bc9ae3d80284fdaf5a')\n",
    "\n",
    "#@save\n",
    "d2l.DATA_HUB['glove.42b.300d'] = (d2l.DATA_URL + 'glove.42B.300d.zip',\n",
    "                                  'b5116e234e9eb9076672cfeabf5469f3eec904fa')\n",
    "\n",
    "#@save\n",
    "d2l.DATA_HUB['wiki.en'] = (d2l.DATA_URL + 'wiki.en.zip',\n",
    "                           'c1816da3821ae9f43899be655002f6c723e91b88')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c6e66b37",
   "metadata": {
    "origin_pos": 5
   },
   "source": [
    "To load these pretrained GloVe and fastText embeddings, we define the following `TokenEmbedding` class.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "45117bfd",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:40:01.615965Z",
     "iopub.status.busy": "2023-08-18T19:40:01.614818Z",
     "iopub.status.idle": "2023-08-18T19:40:01.625449Z",
     "shell.execute_reply": "2023-08-18T19:40:01.624404Z"
    },
    "origin_pos": 6,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "#@save\n",
    "class TokenEmbedding:\n",
    "    \"\"\"Token embedding loaded from pretrained vectors.\"\"\"\n",
    "    def __init__(self, embedding_name):\n",
    "        self.idx_to_token, self.idx_to_vec = self._load_embedding(\n",
    "            embedding_name)\n",
    "        # Index 0 is reserved for the unknown token\n",
    "        self.unknown_idx = 0\n",
    "        self.token_to_idx = {token: idx for idx, token in\n",
    "                             enumerate(self.idx_to_token)}\n",
    "\n",
    "    def _load_embedding(self, embedding_name):\n",
    "        idx_to_token, idx_to_vec = ['<unk>'], []\n",
    "        data_dir = d2l.download_extract(embedding_name)\n",
    "        # GloVe website: https://nlp.stanford.edu/projects/glove/\n",
    "        # fastText website: https://fasttext.cc/\n",
    "        # Read as UTF-8 explicitly so that loading does not depend on the\n",
    "        # platform's default encoding\n",
    "        with open(os.path.join(data_dir, 'vec.txt'), 'r',\n",
    "                  encoding='utf-8') as f:\n",
    "            for line in f:\n",
    "                elems = line.rstrip().split(' ')\n",
    "                token, elems = elems[0], [float(elem) for elem in elems[1:]]\n",
    "                # Skip header information, such as the top row in fastText\n",
    "                if len(elems) > 1:\n",
    "                    idx_to_token.append(token)\n",
    "                    idx_to_vec.append(elems)\n",
    "        # Prepend an all-zero vector for the unknown token\n",
    "        idx_to_vec = [[0] * len(idx_to_vec[0])] + idx_to_vec\n",
    "        return idx_to_token, torch.tensor(idx_to_vec)\n",
    "\n",
    "    def __getitem__(self, tokens):\n",
    "        indices = [self.token_to_idx.get(token, self.unknown_idx)\n",
    "                   for token in tokens]\n",
    "        vecs = self.idx_to_vec[torch.tensor(indices)]\n",
    "        return vecs\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.idx_to_token)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "927d092b",
   "metadata": {
    "origin_pos": 7
   },
   "source": [
    "Below we load the\n",
    "50-dimensional GloVe embeddings\n",
    "(pretrained on a Wikipedia subset).\n",
    "When creating the `TokenEmbedding` instance,\n",
    "the specified embedding file is downloaded\n",
    "if it has not been downloaded already.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "4487556e",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:40:01.630355Z",
     "iopub.status.busy": "2023-08-18T19:40:01.630031Z",
     "iopub.status.idle": "2023-08-18T19:40:18.097173Z",
     "shell.execute_reply": "2023-08-18T19:40:18.096139Z"
    },
    "origin_pos": 8,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading ../data/glove.6B.50d.zip from http://d2l-data.s3-accelerate.amazonaws.com/glove.6B.50d.zip...\n"
     ]
    }
   ],
   "source": [
    "glove_6b50d = TokenEmbedding('glove.6b.50d')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "99425a33",
   "metadata": {
    "origin_pos": 9
   },
   "source": [
    "We output the vocabulary size. The vocabulary contains 400000 words (tokens) and a special unknown token.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "ff0fe3b2",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:40:18.101681Z",
     "iopub.status.busy": "2023-08-18T19:40:18.101067Z",
     "iopub.status.idle": "2023-08-18T19:40:18.107937Z",
     "shell.execute_reply": "2023-08-18T19:40:18.107092Z"
    },
    "origin_pos": 10,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "400001"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(glove_6b50d)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c5839f20",
   "metadata": {
    "origin_pos": 11
   },
   "source": [
    "We can get the index of a word in the vocabulary, and vice versa.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "bee6bad4",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:40:18.111818Z",
     "iopub.status.busy": "2023-08-18T19:40:18.111076Z",
     "iopub.status.idle": "2023-08-18T19:40:18.116940Z",
     "shell.execute_reply": "2023-08-18T19:40:18.116073Z"
    },
    "origin_pos": 12,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(3367, 'beautiful')"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "glove_6b50d.token_to_idx['beautiful'], glove_6b50d.idx_to_token[3367]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "261f53bf",
   "metadata": {
    "origin_pos": 13
   },
   "source": [
    "## Applying Pretrained Word Vectors\n",
    "\n",
    "Using the loaded GloVe vectors,\n",
    "we will demonstrate their semantics\n",
    "by applying them\n",
    "in the following word similarity and analogy tasks.\n",
    "\n",
    "\n",
    "### Word Similarity\n",
    "\n",
    "Similar to :numref:`subsec_apply-word-embed`,\n",
    "in order to find semantically similar words\n",
    "for an input word\n",
    "based on cosine similarities between\n",
    "word vectors,\n",
    "we implement the following `knn`\n",
    "($k$-nearest neighbors) function.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "f9a7b445",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:40:18.120761Z",
     "iopub.status.busy": "2023-08-18T19:40:18.120069Z",
     "iopub.status.idle": "2023-08-18T19:40:18.125858Z",
     "shell.execute_reply": "2023-08-18T19:40:18.124783Z"
    },
    "origin_pos": 15,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "def knn(W, x, k):\n",
    "    # Add 1e-9 for numerical stability\n",
    "    cos = torch.mv(W, x.reshape(-1,)) / (\n",
    "        torch.sqrt(torch.sum(W * W, axis=1) + 1e-9) *\n",
    "        torch.sqrt((x * x).sum()))\n",
    "    _, topk = torch.topk(cos, k=k)\n",
    "    return topk, [cos[int(i)] for i in topk]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "129060b5",
   "metadata": {
    "origin_pos": 16
   },
   "source": [
    "Then, we \n",
    "search for similar words\n",
    "using the pretrained word vectors \n",
    "from the `TokenEmbedding` instance `embed`.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "89565196",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:40:18.129541Z",
     "iopub.status.busy": "2023-08-18T19:40:18.128977Z",
     "iopub.status.idle": "2023-08-18T19:40:18.134362Z",
     "shell.execute_reply": "2023-08-18T19:40:18.133532Z"
    },
    "origin_pos": 17,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "def get_similar_tokens(query_token, k, embed):\n",
    "    topk, cos = knn(embed.idx_to_vec, embed[[query_token]], k + 1)\n",
    "    for i, c in zip(topk[1:], cos[1:]):  # Exclude the input word\n",
    "        print(f'cosine sim={float(c):.3f}: {embed.idx_to_token[int(i)]}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "87927a0a",
   "metadata": {
    "origin_pos": 18
   },
   "source": [
    "The vocabulary of the pretrained word vectors\n",
    "in `glove_6b50d` contains 400000 words and a special unknown token.\n",
    "Excluding the input word and the unknown token,\n",
    "let's find the\n",
    "three most semantically similar words\n",
    "to the word \"chip\" in this vocabulary.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "f6006a6c",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:40:18.137799Z",
     "iopub.status.busy": "2023-08-18T19:40:18.137504Z",
     "iopub.status.idle": "2023-08-18T19:40:18.192497Z",
     "shell.execute_reply": "2023-08-18T19:40:18.191240Z"
    },
    "origin_pos": 19,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cosine sim=0.856: chips\n",
      "cosine sim=0.749: intel\n",
      "cosine sim=0.749: electronics\n"
     ]
    }
   ],
   "source": [
    "get_similar_tokens('chip', 3, glove_6b50d)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5da51ab7",
   "metadata": {
    "origin_pos": 20
   },
   "source": [
    "The following cells output words similar\n",
    "to \"baby\" and \"beautiful\".\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "eefbb4fa",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:40:18.196753Z",
     "iopub.status.busy": "2023-08-18T19:40:18.196073Z",
     "iopub.status.idle": "2023-08-18T19:40:18.222160Z",
     "shell.execute_reply": "2023-08-18T19:40:18.221182Z"
    },
    "origin_pos": 21,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cosine sim=0.839: babies\n",
      "cosine sim=0.800: boy\n",
      "cosine sim=0.792: girl\n"
     ]
    }
   ],
   "source": [
    "get_similar_tokens('baby', 3, glove_6b50d)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "b69e3fad",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:40:18.225735Z",
     "iopub.status.busy": "2023-08-18T19:40:18.225435Z",
     "iopub.status.idle": "2023-08-18T19:40:18.247512Z",
     "shell.execute_reply": "2023-08-18T19:40:18.246311Z"
    },
    "origin_pos": 22,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cosine sim=0.921: lovely\n",
      "cosine sim=0.893: gorgeous\n",
      "cosine sim=0.830: wonderful\n"
     ]
    }
   ],
   "source": [
    "get_similar_tokens('beautiful', 3, glove_6b50d)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f50103d8",
   "metadata": {
    "origin_pos": 23
   },
   "source": [
    "### Word Analogy\n",
    "\n",
    "Besides finding similar words,\n",
    "we can also apply word vectors\n",
    "to word analogy tasks.\n",
    "For example,\n",
    "“man”:“woman”::“son”:“daughter”\n",
    "is the form of a word analogy:\n",
    "“man” is to “woman” as “son” is to “daughter”.\n",
    "Specifically,\n",
    "the word analogy completion task\n",
    "can be defined as:\n",
    "for a word analogy \n",
    "$a : b :: c : d$, given the first three words $a$, $b$ and $c$, find $d$. \n",
    "Denote the vector of word $w$ by $\\textrm{vec}(w)$. \n",
    "To complete the analogy,\n",
    "we will find the word \n",
    "whose vector is most similar\n",
    "to the result of $\\textrm{vec}(c)+\\textrm{vec}(b)-\\textrm{vec}(a)$.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "b98949f5",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:40:18.251109Z",
     "iopub.status.busy": "2023-08-18T19:40:18.250824Z",
     "iopub.status.idle": "2023-08-18T19:40:18.255960Z",
     "shell.execute_reply": "2023-08-18T19:40:18.254917Z"
    },
    "origin_pos": 24,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "def get_analogy(token_a, token_b, token_c, embed):\n",
    "    vecs = embed[[token_a, token_b, token_c]]\n",
    "    x = vecs[1] - vecs[0] + vecs[2]  # vec(b) - vec(a) + vec(c)\n",
    "    topk, cos = knn(embed.idx_to_vec, x, 1)\n",
    "    return embed.idx_to_token[int(topk[0])]  # Token of the nearest vector"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "14092363",
   "metadata": {
    "origin_pos": 25
   },
   "source": [
    "Let's verify the \"male-female\" analogy using the loaded word vectors.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "60790fda",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:40:18.259162Z",
     "iopub.status.busy": "2023-08-18T19:40:18.258879Z",
     "iopub.status.idle": "2023-08-18T19:40:18.283948Z",
     "shell.execute_reply": "2023-08-18T19:40:18.283021Z"
    },
    "origin_pos": 26,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'daughter'"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "get_analogy('man', 'woman', 'son', glove_6b50d)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6cd090ad",
   "metadata": {
    "origin_pos": 27
   },
   "source": [
    "The following completes a\n",
    "“capital-country” analogy:\n",
    "“beijing”:“china”::“tokyo”:“japan”.\n",
    "This demonstrates that the pretrained word vectors\n",
    "capture semantic information.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "fbec7b81",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:40:18.288005Z",
     "iopub.status.busy": "2023-08-18T19:40:18.287287Z",
     "iopub.status.idle": "2023-08-18T19:40:18.314807Z",
     "shell.execute_reply": "2023-08-18T19:40:18.313940Z"
    },
    "origin_pos": 28,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'japan'"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "get_analogy('beijing', 'china', 'tokyo', glove_6b50d)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "46df1c76",
   "metadata": {
    "origin_pos": 29
   },
   "source": [
    "For an\n",
    "“adjective-superlative adjective” analogy\n",
    "such as\n",
    "“bad”:“worst”::“big”:“biggest”,\n",
    "we can see that the pretrained word vectors\n",
    "may also capture syntactic information.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "f437bf1d",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:40:18.318473Z",
     "iopub.status.busy": "2023-08-18T19:40:18.317911Z",
     "iopub.status.idle": "2023-08-18T19:40:18.340768Z",
     "shell.execute_reply": "2023-08-18T19:40:18.339656Z"
    },
    "origin_pos": 30,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'biggest'"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "get_analogy('bad', 'worst', 'big', glove_6b50d)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "43e21c82",
   "metadata": {
    "origin_pos": 31
   },
   "source": [
    "To show that the pretrained word vectors\n",
    "capture the notion of past tense,\n",
    "we can test the syntax using the\n",
    "“present tense-past tense” analogy: “do”:“did”::“go”:“went”.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "2ef91645",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:40:18.345390Z",
     "iopub.status.busy": "2023-08-18T19:40:18.344827Z",
     "iopub.status.idle": "2023-08-18T19:40:18.370172Z",
     "shell.execute_reply": "2023-08-18T19:40:18.369021Z"
    },
    "origin_pos": 32,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'went'"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "get_analogy('do', 'did', 'go', glove_6b50d)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3c1996db",
   "metadata": {
    "origin_pos": 33
   },
   "source": [
    "## Summary\n",
    "\n",
    "* In practice, word vectors that are pretrained on large corpora can be applied to downstream natural language processing tasks.\n",
    "* Pretrained word vectors can be applied to word similarity and analogy tasks.\n",
    "\n",
    "\n",
    "## Exercises\n",
    "\n",
    "1. Test the fastText results using `TokenEmbedding('wiki.en')`.\n",
    "1. When the vocabulary is extremely large, how can we find similar words or complete a word analogy faster?\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ef8be3e7",
   "metadata": {
    "origin_pos": 35,
    "tab": [
     "pytorch"
    ]
   },
   "source": [
    "[Discussions](https://discuss.d2l.ai/t/1336)\n"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  },
  "required_libs": []
 },
 "nbformat": 4,
 "nbformat_minor": 5
}